import os
import time
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score, average_precision_score
import numpy as np
import sys
import pickle
import networkx as nx
from model import GAE

# Vocabulary sizes used throughout the script.
num_type = 91

# Relation-type counts, one per relation category.
num_event_rela_types = 1
num_event_entity_rela_types = 85
num_entity_rela_types = 46

# Total number of relation types across the three categories (132).
num_rela_types = sum((
    num_event_rela_types,
    num_event_entity_rela_types,
    num_entity_rela_types,
))

def get_dataset(data_type):
    """Load one pickled data split from the ``./data`` directory.

    The file read is
    ``./data/<data_type>_pruned_with_bert_max_50_set.pkl`` and must contain a
    three-element sequence ``[all_A_init, all_A_true, all_x_features]``.

    Args:
        data_type: Prefix of the pickle file name (the script uses
            ``"train"``, ``"dev"`` and ``"test"``).

    Returns:
        A tuple ``(all_A_true, all_A_init, all_x_features)``. Note that the
        first two items are swapped relative to their order in the file.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the pickled object does not unpack into three items.
    """
    path = './data/' + data_type + "_pruned_with_bert_max_50_set.pkl"
    # NOTE: pickle can execute arbitrary code while loading; only point this
    # at files from a trusted source.
    # The context manager guarantees the handle is closed even if unpickling
    # fails.
    with open(path, "rb") as handle:
        all_A_init, all_A_true, all_x_features = pickle.load(handle)
    return all_A_true, all_A_init, all_x_features

def _convert_A_to_one_hot(A_sparse):
    """Expand an integer-labelled adjacency tensor into a one-hot tensor.

    ``A_sparse`` is reinterpreted as ``(shape[1], shape[2], 1)``, so its
    leading dimension is expected to have size 1 (presumably the batch axis
    of a ``batch_size=1`` loader -- confirm against the caller).

    Returns:
        A tensor of shape ``(shape[1], shape[2], num_rela_types + 2)``
        created with ``torch.zeros`` defaults, holding a 1 at the index given
        by each label and 0 elsewhere.
    """
    rows, cols = A_sparse.shape[1], A_sparse.shape[2]
    # scatter_ needs int64 indices with the same rank as the target.
    labels = A_sparse.view(rows, cols, 1).long()
    one_hot = torch.zeros(rows, cols, num_rela_types + 2)
    one_hot.scatter_(2, labels, 1)
    return one_hot

def _convert_X_to_one_hot(x_feature_sparse):
    # x_feature_sparse = x_feature_sparse.view(x_feature_sparse.shape[1], 1).long()
    # x_feature_dense = torch.zeros(x_feature_sparse.shape[0], num_type)
    # x_feature_dense.scatter_(1, x_feature_sparse, 1)
    # return x_feature_dense
    # print(x_feature_sparse.shape)
    return x_feature_sparse[0]


def _convert_A_to_input(A_sparse):
    return (A_sparse[0] != 1).float()
    # A_ret = torch.zeros_like(A_sparse[0])
    # for ii in range(len(A_sparse[0])):
    #     for jj in range(len(A_sparse[0][ii])):
    #         if A_sparse[0][ii][jj] != 1:
    #             A_ret



class GAE_dataset(Dataset):
    """Dataset over three parallel, index-aligned sequences.

    Item ``i`` is the tuple ``(features[i], adj_true[i], adj_init[i])``; the
    dataset length is the length of ``features``.
    """

    def __init__(self, features, adj_true, adj_init):
        self.features, self.adj_true, self.adj_init = (
            features, adj_true, adj_init)

    def __len__(self):
        return len(self.features)

    def __getitem__(self, index):
        sample_features = self.features[index]
        sample_adj_true = self.adj_true[index]
        sample_adj_init = self.adj_init[index]
        return sample_features, sample_adj_true, sample_adj_init

# --- training split: loaded from disk, shuffled on every pass ---
(train_adj_true_all,
 train_adj_init_all,
 train_features_all) = get_dataset("train")
train_set = GAE_dataset(
    features=train_features_all,
    adj_true=train_adj_true_all,
    adj_init=train_adj_init_all)
train_loader = DataLoader(train_set, batch_size=1, shuffle=True)

# --- dev split: fixed order ---
(dev_adj_true_all,
 dev_adj_init_all,
 dev_features_all) = get_dataset("dev")
dev_set = GAE_dataset(
    features=dev_features_all,
    adj_true=dev_adj_true_all,
    adj_init=dev_adj_init_all)
dev_loader = DataLoader(dev_set, batch_size=1, shuffle=False)

# --- test split: fixed order ---
(test_adj_true_all,
 test_adj_init_all,
 test_features_all) = get_dataset("test")
test_set = GAE_dataset(
    features=test_features_all,
    adj_true=test_adj_true_all,
    adj_init=test_adj_init_all)
test_loader = DataLoader(test_set, batch_size=1, shuffle=False)




def get_acc(adj_pred, adj_true):
    """Compute label accuracy with and without entries labelled 1.

    Args:
        adj_pred: Score tensor whose ``argmax`` over dimension 1 gives one
            predicted label per flattened entry of ``adj_true``.
        adj_true: Tensor of true labels; it is flattened before comparison.

    Returns:
        A tuple ``(acc_filtered, acc_all)`` of Python floats:

        * ``acc_filtered`` -- accuracy over entries whose true label is not 1.
          It is 0.0 when there is no such entry (instead of NaN, which would
          poison any average taken over several calls).
        * ``acc_all`` -- accuracy over every entry (0.0 for empty input).
    """
    pred_labels = torch.argmax(adj_pred, dim=1).cpu().numpy()
    true_labels = adj_true.view(-1).cpu().numpy()

    total = len(true_labels)
    if total == 0:
        # Nothing to score; avoid dividing by zero.
        return 0.0, 0.0

    matches = pred_labels == true_labels
    acc_all = float(matches.sum() / total)

    # Boolean mask replaces the per-element Python loop over all entries.
    keep = true_labels != 1
    num_kept = int(keep.sum())
    if num_kept == 0:
        print("get_acc: no entry with a true label other than 1; "
              "filtered accuracy reported as 0.0")
        return 0.0, acc_all

    acc_filtered = float(matches[keep].sum() / num_kept)
    return acc_filtered, acc_all


# Device selection: the script targets GPU index 2. On hosts that do not
# expose at least three CUDA devices, fall back to the default GPU and then
# to the CPU rather than failing when the model is moved to the device.
if torch.cuda.is_available() and torch.cuda.device_count() > 2:
    device = torch.device("cuda:2")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Layer sizes passed to GAE; input_dim has the same value as num_type (91).
input_dim = 91
hidden1_dim = 128
hidden2_dim = 64

# The last constructor argument is the total number of relation types.
model = GAE(input_dim, hidden1_dim, hidden2_dim, num_rela_types)
model.to(device)

model.train()
optimizer = AdamW(model.parameters(), lr=1e-5)

def _prepare_batch(features, adj_true, adj_init):
    """Convert one DataLoader batch into model inputs on ``device``.

    Returns ``(features, adj_true, adj_init, index_mat)`` where ``index_mat``
    holds every ordered pair ``(i, j)`` of indices into ``features``.
    """
    features = _convert_X_to_one_hot(features).float().to(device)
    adj_init = _convert_A_to_input(adj_init).to(device)
    node_ids = torch.arange(len(features))
    index_mat = torch.cartesian_prod(node_ids, node_ids).to(device)
    adj_true = adj_true.long().to(device)
    return features, adj_true, adj_init, index_mat


# Best filtered accuracy seen so far; a checkpoint is written whenever it
# improves.
# NOTE(review): model selection below uses the *test* split, and dev_loader
# is not used anywhere in this script as visible here -- confirm that this
# is intended.
best_test = 0.
for epoch in range(800):
    all_loss = 0.
    all_train_acc_filtered, all_train_acc_all = 0., 0.
    model.train()
    for features, adj_true, adj_init in train_loader:
        features, adj_true, adj_init, index_mat = _prepare_batch(
            features, adj_true, adj_init)
        adj_pred = model(features, adj_init, index_mat)
        optimizer.zero_grad()

        # adj_pred is indexed below by flattened node-pair position and then
        # by class, i.e. it is treated as (num_pairs, num_classes), while
        # adj_true is flattened to (num_pairs,).
        #
        # For each label value mm in 0..131 (132 == num_rela_types), the
        # logits of the positions whose true label is mm are overwritten
        # in place with a constant row that has 1e10 at index mm. Such rows
        # have ~0 cross-entropy and no gradient flows back through them to
        # the model. For mm == 2 and mm >= 88 the last matching position is
        # left untouched ([:-1]), so it keeps the model's own logits and
        # still contributes to the loss.
        for mm in range(132):
            filled_val = torch.zeros_like(adj_pred[0])
            filled_val[mm] = 1e10
            if mm >= 88 or mm == 2:
                adj_pred[(adj_true.view(-1) == mm).nonzero(as_tuple=False)[:-1]] = filled_val
            else:
                adj_pred[(adj_true.view(-1) == mm).nonzero(as_tuple=False)] = filled_val
        # Label 1 is overwritten once more here (the loop above already
        # handled it through the else branch).
        filled_val = torch.zeros_like(adj_pred[0])
        filled_val[1] = 1e10
        adj_pred[(adj_true.view(-1) == 1).nonzero(as_tuple=False)] = filled_val

        loss = F.cross_entropy(adj_pred, adj_true.view(-1))
        loss.backward()
        optimizer.step()
        all_loss += loss.item()
        # NOTE(review): accuracy is measured on the overwritten logits, so
        # the training numbers include positions forced to the true label.
        train_acc_filtered, train_acc_all = get_acc(adj_pred, adj_true)
        all_train_acc_filtered += train_acc_filtered
        all_train_acc_all += train_acc_all
    all_train_acc_filtered /= len(train_loader)
    all_train_acc_all /= len(train_loader)
    print("epoch " + str(epoch) + " loss: " + str(all_loss) + " train acc filtered: " + str(round(all_train_acc_filtered, 4)) + " train acc all: " + str(round(all_train_acc_all, 4)))

    all_test_acc_filtered, all_test_acc_all = 0., 0.
    model.eval()
    # Evaluation needs no gradients; disabling autograd avoids building a
    # graph for every test sample.
    with torch.no_grad():
        for features, adj_true, adj_init in test_loader:
            features, adj_true, adj_init, index_mat = _prepare_batch(
                features, adj_true, adj_init)
            adj_pred = model(features, adj_init, index_mat)
            test_acc_filtered, test_acc_all = get_acc(adj_pred, adj_true)
            all_test_acc_filtered += test_acc_filtered
            all_test_acc_all += test_acc_all
    all_test_acc_filtered /= len(test_loader)
    all_test_acc_all /= len(test_loader)
    if all_test_acc_filtered > best_test:
        best_test = all_test_acc_filtered
        torch.save(model.state_dict(), "best_model.pt")
    print("epoch " + str(epoch) + " test acc filtered: " + str(round(all_test_acc_filtered, 4)) + " test acc all: " + str(round(all_test_acc_all, 4)))




    
    








































